# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np

from paddlenlp.datasets import MapDataset


def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    """Build a ``paddle.io.DataLoader`` over ``dataset``.

    Args:
        dataset: Dataset to load from. When ``trans_fn`` is given it must
            expose a ``map`` method, whose return value replaces it.
        mode: ``'train'`` selects a shuffling distributed batch sampler;
            any other value selects a non-shuffling plain batch sampler.
        batch_size: Number of examples per batch.
        batchify_fn: Passed to the loader as ``collate_fn``.
        trans_fn: Optional per-example transform applied through
            ``dataset.map`` before any sampler is created.

    Returns:
        A ``paddle.io.DataLoader`` created with ``return_list=True``.
    """
    if trans_fn:
        dataset = dataset.map(trans_fn)

    is_train = mode == 'train'
    # Training uses the distributed sampler, which splits the data across
    # cards for multi-card parallel training. Every other mode is treated as
    # single-card evaluation, for which a plain BatchSampler is enough.
    sampler_cls = (paddle.io.DistributedBatchSampler
                   if is_train else paddle.io.BatchSampler)
    sampler = sampler_cls(dataset, batch_size=batch_size, shuffle=is_train)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=sampler,
        collate_fn=batchify_fn,
        return_list=True)


def read_text_pair(data_path, is_test=False):
    """Lazily yield text-pair examples from a tab-separated UTF-8 file.

    Each line has trailing whitespace stripped and is then split on tabs.
    Lines whose column count does not match the expected layout are
    skipped silently.

    For example, the labelled line
    ``喜欢打篮球的男生喜欢什么样的女生<TAB>爱打篮球的男生喜欢什么样的女生<TAB>1``
    is yielded as
    ``{'query1': '喜欢打篮球的男生喜欢什么样的女生',
    'query2': '爱打篮球的男生喜欢什么样的女生', 'label': '1'}``.

    Args:
        data_path: Path of the file to read.
        is_test: When false, lines must have three columns
            (query1, query2, label). When true, lines must have two columns
            (query1, query2) and no ``'label'`` key is produced.

    Yields:
        dict: ``{'query1': ..., 'query2': ...}``, plus ``'label'`` (the raw
        string from the file, not converted to int) when ``is_test`` is
        false.
    """
    # Labelled rows carry three columns; unlabelled (test) rows carry two.
    expected_fields = 2 if is_test else 3
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if len(data) != expected_fields:
                continue
            example = {'query1': data[0], 'query2': data[1]}
            if not is_test:
                example['label'] = data[2]
            yield example


def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    """Turn one plain-text pair into token-ID features.

    The two texts of the example are handed to ``tokenizer`` as a sentence
    pair so that the pretrained model's tokenizer converts them to IDs.

    Args:
        example: Mapping with ``"query1"`` and ``"query2"`` and, unless
            ``is_test`` is true, ``"label"``.
        tokenizer: Callable invoked with the keywords ``text``,
            ``text_pair`` and ``max_seq_len``; its result must be
            indexable by ``"input_ids"`` and ``"token_type_ids"``.
        max_seq_length: Forwarded to the tokenizer as ``max_seq_len``.
        is_test: When true the label is not read and not returned
            (prediction / evaluation without labels).

    Returns:
        ``(input_ids, token_type_ids, label)`` where ``label`` is a
        one-element ``int64`` NumPy array, or ``(input_ids, token_type_ids)``
        when ``is_test`` is true.

    Example, from the original author's notes (the exact IDs depend on the
    tokenizer in use), converting the first training example::

        input_ids, token_type_ids, label = convert_example(train_ds[0], tokenizer)
        print(input_ids)
        # [1, 692, 811, 445, 2001, 497, 5, 654, 21, 692, 811, 614, 356, 314,
        #  5, 291, 21, 2, 329, 445, 2001, 497, 5, 654, 21, 692, 811, 614,
        #  356, 314, 5, 291, 21, 2]
        print(token_type_ids)
        # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        #  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        print(label)
        # [1]
    """
    encoding = tokenizer(
        text=example["query1"],
        text_pair=example["query2"],
        max_seq_len=max_seq_length)
    features = (encoding["input_ids"], encoding["token_type_ids"])

    # Unlabelled data: hand back only the model inputs.
    if is_test:
        return features

    target = np.array([example["label"]], dtype="int64")
    return features + (target,)